"""BPPO (Behavior Proximal Policy Optimization) for fine-tuning a causal language model.
"""
from critic import ActorCritic
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
import json
from peft import PeftModel
import os
from datetime import datetime
import wandb
from tqdm import tqdm
import argparse
from inference import inference
from eval import eval_file
from sft_instruct import seed_everything
import argparse
import random


class GSM8kChatDataset(Dataset):
    """Map-style dataset that renders each question as a chat-formatted prompt string.

    Every record in ``data`` must provide a ``"question"`` entry. The tokenizer is
    only used for its ``apply_chat_template`` method; no tokenization happens here.
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the untokenized chat prompt for record ``idx``."""
        messages = [
            {"role": "system", "content": "Please solve the following problem step by step."},
            {"role": "user", "content": self.data[idx]["question"]},
        ]
        # tokenize=False keeps the prompt as text; batching/tokenization is left to the collate function.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt


class BPPODataset(Dataset):
    """Map-style dataset producing token-level training samples with per-token rewards.

    Each record in ``data`` must provide ``'question'``, ``'answer'`` and ``'rewards'``;
    ``'rewards'`` must be a list with one entry per token of the tokenized answer
    (enforced by the length assertions in ``__getitem__``).
    """

    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        # NOTE: stored for callers, but no truncation is applied anywhere in this class.
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Build the sample dict for record ``idx``.

        Returns a dict with ``input_ids``, ``labels``, ``attention_mask`` and
        ``pad_left_rewards`` (all lists of equal length), the raw ``rewards`` list,
        and ``start_idx`` / ``last_idx`` marking the first response position and
        the last position of the sequence.
        """
        record = self.data[idx]
        question = record['question']
        answer = record['answer']

        chat_prompt = self.tokenizer.apply_chat_template(
            [
                {"role": "system", "content": "Please solve the following problem step by step."},
                {"role": "user", "content": question},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompt_enc = self.tokenizer(chat_prompt, add_special_tokens=False)
        answer_enc = self.tokenizer(answer, add_special_tokens=False)

        pad_id = self.tokenizer.pad_token_id
        prompt_ids = prompt_enc["input_ids"]
        answer_ids = answer_enc["input_ids"]
        prompt_len = len(prompt_ids)

        # Sequence layout: prompt tokens, answer tokens, then one trailing pad token
        # that is attended to (mask 1) and acts as the final step.
        input_ids = prompt_ids + answer_ids + [pad_id]
        attention_mask = prompt_enc["attention_mask"] + answer_enc["attention_mask"] + [1]

        # Prompt positions are filled with the pad id (not -100): the labels are later
        # used as gather indices, so they must stay valid vocabulary ids.
        labels = [pad_id] * prompt_len + answer_ids + [pad_id]

        # -100 marks prompt positions as "no reward"; the trailing pad token gets reward 0.
        step_rewards = record['rewards']
        pad_left_rewards = [-100] * prompt_len + step_rewards + [0]

        total_len = len(input_ids)
        assert total_len == len(labels)
        assert total_len == len(pad_left_rewards)
        assert total_len == len(attention_mask)

        return {
            'input_ids': input_ids,
            'labels': labels,
            'start_idx': prompt_len,
            'last_idx': total_len - 1,
            'rewards': step_rewards,
            'pad_left_rewards': pad_left_rewards,
            'attention_mask': attention_mask,
        }

def collate_fn(batch, tokenizer):
    """Right-pad a list of ``BPPODataset`` samples into batch tensors.

    Args:
        batch: list of sample dicts with ``input_ids``, ``labels``,
            ``attention_mask``, ``pad_left_rewards``, ``start_idx`` and ``last_idx``.
        tokenizer: object exposing ``pad_token_id``, used as the padding value for
            ``input_ids`` and ``labels``.

    Returns:
        dict with padded ``input_ids`` / ``labels`` / ``attention_mask`` (long),
        ``rewards`` (float32, padded with the -100 sentinel), ``reward_mask``
        (int, 1 where the reward is not -100) and the per-sample ``start_idx`` /
        ``last_idx`` lists.
    """
    pad_id = tokenizer.pad_token_id

    def _pad(key, dtype, padding_value):
        # Convert one field of every sample to a tensor and right-pad to the longest sample.
        return torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(sample[key], dtype=dtype) for sample in batch],
            batch_first=True,
            padding_value=padding_value,
        )

    padded_inputs = _pad("input_ids", torch.long, pad_id)
    padded_labels = _pad("labels", torch.long, pad_id)
    padded_attention_mask = _pad("attention_mask", torch.long, 0)
    # Rewards are kept as float32: casting to long would silently truncate fractional
    # rewards and make any tensor derived via zeros_like(rewards) integer-typed.
    padded_rewards = _pad("pad_left_rewards", torch.float32, -100)

    # -100 marks prompt and padding positions, which carry no reward.
    reward_mask = (padded_rewards != -100).int()

    return {
        "input_ids": padded_inputs,
        "attention_mask": padded_attention_mask,
        "labels": padded_labels,
        "rewards": padded_rewards,
        'reward_mask': reward_mask,
        'start_idx': [sample["start_idx"] for sample in batch],
        'last_idx': [sample["last_idx"] for sample in batch],
    }

def collate_fn_eval(batch, tokenizer):
    """Tokenize a batch of prompt strings with padding and truncation enabled.

    The tokenizer's return value is passed through unchanged; ``return_tensors="pt"``
    requests PyTorch tensors.
    """
    tokenize_kwargs = {
        "padding": True,
        "truncation": True,
        "max_length": 1024,  # limit the input length
        "return_tensors": "pt",
    }
    return tokenizer(batch, **tokenize_kwargs)

class BPPO:
    def __init__(self,
        # actor_critic: ActorCritic,
        coloned_model: AutoModelForCausalLM,
        config: dict,
        dataloader: DataLoader,
        device: str,
        tokenizer: AutoTokenizer,
        tag: str,
        flag: str,
        base_model_name: str,
    ):
        """Wrap the language model in an ``ActorCritic`` and store the run settings.

        Args:
            coloned_model: language model handed to ``ActorCritic``.
            config: nested settings dict; the training methods read its
                ``'pretrain_value'`` and ``'iterative_train_policy'`` sections.
            dataloader: training batches consumed by the training methods.
            device: device that batches are moved to.
            tokenizer: tokenizer kept alongside the model.
            tag: run label used in wandb run names.
            flag: run label used in wandb run / project names.
            base_model_name: model name used to load the evaluation tokenizer.
        """
        self.actor_critic = ActorCritic(coloned_model)

        # Training inputs.
        self.config, self.dataloader, self.device = config, dataloader, device

        # Labels that only affect experiment naming.
        self.tag, self.flag = tag, flag

        # Tokenization / evaluation settings.
        self.tokenizer = tokenizer
        self.base_model_name = base_model_name

    def pretrain_sarsa_duel(self):
        """Pre-train the value and advantage heads with a one-step SARSA TD loss.

        The language model is frozen; only ``value_head`` and ``advantage_head`` are
        optimized. For each position the squared error between the Q-value of the
        label token and ``r_t + gamma * Q_{t+1}`` is averaged over positions where
        ``reward_mask`` is 1. Hyper-parameters are read from
        ``self.config['pretrain_value']``. The loss is logged to wandb and the heads
        are saved to ``<save_dir>/bppo-Q-value-pretrained``.
        """
        # Freeze the LM; only the two critic heads are trained.
        self.actor_critic.set_lm_non_trainable()
        self.actor_critic.set_head_trainable('value')
        self.actor_critic.set_head_trainable('advantage')

        config = self.config['pretrain_value']
        num_epochs = config['num_epochs']
        batch_size = config['batch_size']
        gamma = config['gamma']
        use_cosine_lr = config['use_cosine_lr']
        save_dir = config['save_dir']
        lr = config['lr']
        # Take the seed from the config instead of a module-level global, which only
        # exists when this file is executed as a script.
        seed = self.config.get('seed')

        optimizer = torch.optim.Adam(
            [
                {'params': self.actor_critic.value_head.parameters()},
                {'params': self.actor_critic.advantage_head.parameters()}
            ],
            lr=lr
        )
        # The scheduler is stepped once per batch, so the annealing horizon must be the
        # total number of batches, not the number of epochs.
        total_steps = max(1, num_epochs * len(self.dataloader))
        scheduler = (
            torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
            if use_cosine_lr else None
        )

        timestamp = datetime.now().strftime("%m%d-%H%M%S")
        wandb.init(project='bppo-value', entity='Your wandb entity', name=f'bppo-duel-{self.tag}-lr{lr}-bs{batch_size}-gamma{gamma}-seed{seed}-{timestamp}')

        step = 0
        for epoch in tqdm(range(num_epochs)):
            for batch in tqdm(self.dataloader):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                rewards = batch['rewards'].to(self.device)
                labels = batch['labels'].to(self.device)
                # 1 where a real reward exists (i.e. the padded reward is not -100).
                reward_mask = batch['reward_mask'].to(self.device)

                outputs = self.actor_critic(input_ids=input_ids, attention_mask=attention_mask)
                q_values = outputs['Q_values']  # [B, T, V]
                # Q-value of the label token at every position.
                chosen_q_values = q_values.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # [B, T]

                # Q-value of the following position; the last position bootstraps from 0.
                next_q_values = torch.cat([chosen_q_values[:, 1:], torch.zeros((chosen_q_values.size(0), 1), device=self.device)], dim=1)  # [B, T]

                # 1 where the *next* position is still a rewarded step, so bootstrapping
                # is switched off at the end of each sequence.
                done_mask = torch.cat([reward_mask[:, 1:], torch.zeros((reward_mask.size(0), 1), device=self.device)], dim=1)

                td_target = rewards + gamma * next_q_values * done_mask
                td_errors = (td_target - chosen_q_values).pow(2)
                masked_td_errors = td_errors * reward_mask
                # Mean over rewarded positions; the epsilon guards against an all-zero mask.
                sarsa_loss = masked_td_errors.sum() / (reward_mask.sum() + 1e-8)

                optimizer.zero_grad()
                sarsa_loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

                if self.config.get('verbose', False):
                    print(f"Epoch {epoch + 1}, SARSA Loss: {sarsa_loss.item()}")
                wandb.log({'sarsa_loss': sarsa_loss.item()}, step=step)
                step += 1

        save_path = os.path.join(save_dir, 'bppo-Q-value-pretrained')
        self.actor_critic.save_value_pretrained(save_path)
        wandb.finish()

    def pretrain_value(self):
        """Pre-train only the value head with a one-step TD (MSE) loss.

        The language model is frozen and ``value_head`` is the only optimized module.
        For each position the squared error between ``V_t`` and
        ``r_t + gamma * V_{t+1}`` is averaged over positions where ``reward_mask``
        is 1. Hyper-parameters are read from ``self.config['pretrain_value']``. The
        loss is logged to wandb and the head is saved to ``<save_dir>/bppo-value``.
        """
        self.actor_critic.set_lm_non_trainable()
        self.actor_critic.set_head_trainable('value')

        config = self.config['pretrain_value']
        num_epochs = config['num_epochs']
        batch_size = config['batch_size']
        gamma = config['gamma']
        use_cosine_lr = config['use_cosine_lr']
        save_dir = config['save_dir']
        lr = config['lr']

        optimizer = torch.optim.Adam(
            [
                {'params': self.actor_critic.value_head.parameters()},
            ],
            lr=lr
        )
        # The scheduler is stepped once per batch, so the annealing horizon must be the
        # total number of batches, not the number of epochs.
        total_steps = max(1, num_epochs * len(self.dataloader))
        scheduler = (
            torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
            if use_cosine_lr else None
        )

        timestamp = datetime.now().strftime("%m%d-%H%M%S")
        wandb.init(project='bppo-value', entity='Your wandb entity', name=f'bppo-value-{self.flag}-{self.tag}-lr{lr}-bs{batch_size}-gamma{gamma}-{timestamp}')

        step = 0
        for epoch in tqdm(range(num_epochs)):
            for batch in tqdm(self.dataloader):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                rewards = batch['rewards'].to(self.device)
                # 1 where a real reward exists (i.e. the padded reward is not -100).
                reward_mask = batch['reward_mask'].to(self.device)

                outputs = self.actor_critic(input_ids=input_ids, attention_mask=attention_mask)
                value = outputs['value']

                # Value of the following position; the last position bootstraps from 0.
                next_value = torch.cat([value[:, 1:], torch.zeros((value.size(0), 1), device=self.device)], dim=1)  # [B, T]

                # 1 where the *next* position is still a rewarded step, so bootstrapping
                # is switched off at the end of each sequence.
                done_mask = torch.cat([reward_mask[:, 1:], torch.zeros((reward_mask.size(0), 1), device=self.device)], dim=1)

                td_target = rewards + gamma * next_value * done_mask
                td_errors = (td_target - value).pow(2)
                masked_td_errors = td_errors * reward_mask
                # Mean over rewarded positions; the epsilon guards against an all-zero mask.
                mse_loss = masked_td_errors.sum() / (reward_mask.sum() + 1e-8)

                optimizer.zero_grad()
                mse_loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

                if self.config.get('verbose', False):
                    print(f"Epoch {epoch + 1}, MSE Loss: {mse_loss.item()}")
                wandb.log({'mse_loss': mse_loss.item()}, step=step)
                step += 1

        save_path = os.path.join(save_dir, 'bppo-value')
        self.actor_critic.save_value_pretrained(save_path)
        wandb.finish()


    # def pretrain_V(self, num_epochs=1, batch_size=16):
    #     pass
    def load_value_pretrained(self, save_path):
        self.actor_critic.load_value_pretrained(save_path)

    def get_advantage(self, outputs, config, batch):
        if config['type'] == 'sarsa':
            advantage_values = outputs['advantage_values']
            labels = batch['labels'].to(self.device)
            advantage_values = advantage_values.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # [B, T]
            return advantage_values

        elif config['type'] == 'gae':
            value = outputs['value']  # [batch_size, seq_len]
            rewards = batch['rewards']  # [batch_size, seq_len]
            rewards_mask = batch['reward_mask']  # [batch_size, seq_len] # only the 1 value is valid
            gamma = config['gamma'] if config['gamma'] is not None else 0.99  # discount factor
            gae_lambda = config['gae_lambda'] if config['gae_lambda'] is not None else 0.98  # gae lambda
            start_indices = batch['start_idx']
            last_indices = batch['last_idx']

            advantages = torch.zeros_like(rewards).to(self.device)
            for batch_idx in range(len(rewards)):
                reward = rewards[batch_idx]  # [seq_len]
                value_batch = value[batch_idx]  # [seq_len]
                mask = rewards_mask[batch_idx]  # [seq_len]
                start_idx = start_indices[batch_idx]
                last_idx = last_indices[batch_idx]
                # initialize advantage and gae
                gae = 0
                # backward calculate gae
                for t in reversed(range(start_idx, last_idx+1)):
                    if t == last_idx:
                        next_value = 0  # the next_value of the last time step is 0
                    else:
                        next_value = value_batch[t + 1]

                    # calculate td error
                    delta = reward[t] + gamma * next_value - value_batch[t]

                    # update gae
                    gae = delta + gamma * gae_lambda * gae
                    # only keep the valid position advantage
                    # if mask[t] == 1:
                    assert mask[t] == 1
                    advantages[batch_idx, t] = gae  
                                 
            return advantages

    def get_logits(self, model, data_loader, target_steps):
        """Cache label-token log-probabilities for the first ``target_steps`` batches.

        Args:
            model: callable module returning a mapping with ``'logits'`` ([B, T, V]);
                it is switched to eval mode for the pass and back to train mode after.
            data_loader: iterable of collated batches with ``input_ids``,
                ``attention_mask`` and ``labels``.
            target_steps: maximum number of batches to process.

        Returns:
            ``(old_logits_list, batch_list)``: a list of float32 NumPy arrays of shape
            [B, T] holding the log-probability of each label token, and the list of
            the corresponding batches, in the same order.
        """
        old_logits_list = []
        batch_list = []
        model.eval()
        try:
            with torch.no_grad():
                for i, old_batch in enumerate(tqdm(data_loader, desc="Precomputing old logits")):
                    if i >= target_steps:
                        break
                    input_ids = old_batch['input_ids'].to(self.device)
                    attention_mask = old_batch['attention_mask'].to(self.device)
                    labels = old_batch['labels'].to(self.device)

                    # Use the model that was put in eval mode above for the forward pass.
                    logits = model(input_ids=input_ids, attention_mask=attention_mask)['logits']
                    log_probs = torch.log_softmax(logits, dim=-1)  # [B, T, V]
                    # Log-probability of the label token at every position.
                    chosen_log_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # [B, T]

                    # Cast to float32 first: NumPy cannot represent bfloat16 tensors.
                    old_logits_list.append(chosen_log_probs.detach().float().cpu().numpy())
                    batch_list.append(old_batch)

                    # Drop the references first so empty_cache() can actually release them.
                    del input_ids, attention_mask, labels, logits, log_probs, chosen_log_probs
                    torch.cuda.empty_cache()
        finally:
            # Restore train mode even if the pass is interrupted by an exception.
            model.train()
        return old_logits_list, batch_list


    def iterative_train_policy(self):
        """Run iterative clipped-PPO policy training, keeping the most accurate checkpoint.

        Settings are read from ``self.config['iterative_train_policy']`` (plus
        ``self.config['seed']`` for the run name). Each iteration:

        1. (from the second iteration on) reloads the best saved policy and rebuilds
           the optimizer and scheduler;
        2. caches the label-token log-probabilities of the current policy via
           ``get_logits``;
        3. performs ``target_steps`` optimizer steps, each accumulating gradients
           over ``gradient_accumulation_steps`` cached batches;
        4. evaluates with ``eval_model`` and saves the policy if its accuracy is at
           least the best accuracy so far (initialised to ``start_accuracy``).

        Returns:
            Path of the directory holding the best policy, the per-iteration
            inference results and ``accuracy_list.json``.
        """
        seed = self.config['seed']
        config = self.config['iterative_train_policy']
        num_epochs = config['num_epochs']
        num_iter = config['num_iter']
        lr = config['lr']
        save_dir = config['save_dir']
        gamma = config['gamma']
        use_cosine_lr = config['use_cosine_lr']
        clip_ratio = config['clip_ratio']
        gradient_accumulation_steps = config['gradient_accumulation_steps']
        type_ = config['type']
        start_accuracy = config['start_accuracy']
        num_batchs_per_iter = config['num_batchs_per_iter']
        timestamp = datetime.now().strftime("%m%d-%H%M%S")

        # Only the language model is updated here; the critic heads stay fixed.
        self.actor_critic.set_head_non_trainable('value')
        self.actor_critic.set_head_non_trainable('advantage')
        self.actor_critic.lm.print_trainable_parameters()

        wandb.init(project=f'bppo-policy-{self.flag}', entity='xjxyys', name=f'bppo-{self.tag}-start-acc-{start_accuracy}-type-{type_}-lr{lr}-epoch-{num_epochs}-iter-{num_iter}-clip-{clip_ratio}-gradient_accumul-{gradient_accumulation_steps}-step_per_iter-{num_batchs_per_iter}-seed-{seed}-gamma{gamma}-{timestamp}')

        # Every optimizer step consumes `gradient_accumulation_steps` batches, so the
        # number of steps is bounded by the batches the dataloader can supply.
        available_steps = len(self.dataloader) // gradient_accumulation_steps
        target_steps = min(num_batchs_per_iter, available_steps)
        if target_steps < num_batchs_per_iter:
            print(f'Only {len(self.dataloader)} batches available; running {target_steps} optimizer steps per iteration instead of {num_batchs_per_iter}.')
        warmup_steps = int(target_steps * 0.03)

        def build_optimizer_and_scheduler():
            # Fresh Adam state (and, when enabled, a fresh cosine schedule) for the current LM weights.
            new_optimizer = torch.optim.Adam(
                [
                    {'params': self.actor_critic.lm.parameters()}
                ],
                lr=lr
            )
            # Only create the scheduler when requested: constructing it already rescales
            # the optimizer's learning rate to the warmup start value.
            new_scheduler = get_cosine_schedule_with_warmup(
                new_optimizer, num_warmup_steps=warmup_steps, num_training_steps=target_steps
            ) if use_cosine_lr else None
            new_optimizer.zero_grad()
            return new_optimizer, new_scheduler

        optimizer, scheduler = build_optimizer_and_scheduler()

        accuracy_dict = {f'Iter{i}': 0 for i in range(num_iter)}
        save_path = os.path.join(save_dir, f'bppo-best-policy-T{timestamp}')
        step = 0
        # Save the starting policy so later iterations always have a checkpoint to reload.
        self.actor_critic.save_policy_pretrained(save_path)

        best_accuracy = start_accuracy
        print('start accuracy:', best_accuracy)
        for iter_idx in tqdm(range(num_iter), desc="Iterative training"):
            if iter_idx > 0:
                # Restart every later iteration from the best policy found so far.
                torch.cuda.empty_cache()
                self.actor_critic.load_policy_pretrained(save_path)
                self.actor_critic.lm.to(self.device)
                optimizer, scheduler = build_optimizer_and_scheduler()

            current_objective_list = []
            # Behaviour-policy log-probabilities, fixed for the whole iteration.
            old_logits_list, batch_list = self.get_logits(self.actor_critic, self.dataloader, target_steps * gradient_accumulation_steps)

            for i in tqdm(range(target_steps), desc="Training policy"):
                for j in range(gradient_accumulation_steps):
                    sample_idx = i * gradient_accumulation_steps + j
                    batch = batch_list[sample_idx]

                    input_ids = batch['input_ids'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)
                    labels = batch['labels'].to(self.device)
                    reward_mask = batch['reward_mask'].to(self.device)

                    outputs = self.actor_critic(input_ids=input_ids, attention_mask=attention_mask)
                    logits = outputs['logits']

                    log_probs = torch.log_softmax(logits, dim=-1)  # [B, T, V]
                    log_likelihood = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # [B, T]
                    advantage = self.get_advantage(outputs, config, batch)
                    old_log_likelihood = torch.tensor(old_logits_list[sample_idx]).to(self.device)

                    # Clipped PPO surrogate, averaged over rewarded positions.
                    ratio = torch.exp(log_likelihood - old_log_likelihood)  # [B, T]
                    curr_obj = ratio * advantage * reward_mask
                    curr_obj_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantage * reward_mask
                    ppo_loss = -torch.min(curr_obj, curr_obj_clipped).sum() / (reward_mask.sum() + 1e-8)

                    # Scale so the accumulated gradient is the mean over the micro-batches.
                    ppo_loss = ppo_loss / gradient_accumulation_steps
                    ppo_loss.backward()
                    if j == gradient_accumulation_steps - 1:
                        optimizer.step()
                        optimizer.zero_grad()
                        if scheduler is not None:
                            scheduler.step()
                        step += 1
                        # Logged value is the un-scaled loss of the last micro-batch only.
                        wandb.log({'ppo_loss': ppo_loss.item() * gradient_accumulation_steps}, step=step)
                        current_objective_list.append(ppo_loss.item() * gradient_accumulation_steps)

                    # Drop the references first so empty_cache() can actually release them.
                    del input_ids, attention_mask, labels, reward_mask, outputs, logits, log_probs, log_likelihood, advantage, old_log_likelihood, ratio, curr_obj, curr_obj_clipped
                    torch.cuda.empty_cache()

            eval_path = os.path.join(save_path, f'inference_results-Iter{iter_idx}.jsonl')
            scores = self.eval_model(self.actor_critic.lm, self.base_model_name, eval_path)
            # Guard against an empty score list instead of dividing by zero.
            accuracy = sum(scores) / len(scores) if scores else 0.0
            accuracy_dict[f'Iter{iter_idx}'] = accuracy
            print(f'accuracy: {accuracy}')
            wandb.log({'accuracy': accuracy}, step=step)

            if accuracy >= best_accuracy:
                print('save the best model')
                best_accuracy = accuracy
                self.actor_critic.save_policy_pretrained(save_path)

        with open(os.path.join(save_path, 'accuracy_list.json'), 'w') as f:
            json.dump(accuracy_dict, f)

        return save_path

    def eval_model(self, model, base_model_name, path=''):
        """Greedy-decode the GSM8K test set with ``model`` and score the generations.

        Args:
            model: generation-capable model; it is moved to ``cuda:0`` when CUDA is
                available (otherwise CPU) and evaluated in eval mode.
            base_model_name: name or path used to load a left-padding tokenizer.
            path: output JSONL file; each test record is written with an added
                ``"generated"`` field before being scored.

        Returns:
            Whatever ``eval_file(path)`` returns (the caller treats it as a list of
            per-example scores).
        """
        data_path = './gsm8k/test.jsonl'
        with open(data_path, 'r') as f:
            data = [json.loads(line) for line in f.readlines()]

        tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True, padding_side='left')
        # Batched tokenization with padding needs a pad token; fall back to EOS for
        # tokenizers that do not define one (same rule as load_pretrained_model).
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        dataset = GSM8kChatDataset(data, tokenizer)
        dataloader = DataLoader(dataset, batch_size=16, collate_fn=lambda x: collate_fn_eval(x, tokenizer))

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Remember the current mode so evaluation does not leave a training model in eval mode.
        was_training = getattr(model, 'training', False)
        model.to(device)
        model.eval()

        generation_config = {
            "max_new_tokens": 2048,          # limit the generation length
            "do_sample": False,
            "pad_token_id": tokenizer.eos_token_id,
            "use_cache": True               # enable the KV cache
        }

        inference_results = []
        try:
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    inputs = {k: v.to(device) for k, v in batch.items()}
                    outputs = model.generate(**inputs, **generation_config)
                    batch_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
                    inference_results.extend(batch_texts)
        finally:
            if was_training:
                model.train()

        with open(f'{path}', 'w') as f:
            for text, item in zip(inference_results, data):
                item["generated"] = text
                f.write(json.dumps(item) + '\n')

        scores = eval_file(f'{path}')

        return scores



def load_pretrained_model(adapter_path, base_model_name, is_trainable=False):
    """Load the base causal LM and tokenizer, optionally wrapping the model with a PEFT adapter.

    Args:
        adapter_path: directory of a PEFT adapter; ``None`` or an empty string
            returns the plain base model.
        base_model_name: model name or path passed to ``from_pretrained``.
        is_trainable: forwarded to ``PeftModel.from_pretrained``.

    Returns:
        ``(model, tokenizer)``; the tokenizer's pad token falls back to its EOS
        token when none is defined.
    """
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map="auto", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # For some models like Llama/Mistral

    # Truthiness check: an empty string (the CLI default) means "no adapter" and must
    # not be handed to PeftModel.from_pretrained.
    if not adapter_path:
        return base_model, tokenizer

    model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=is_trainable)
    return model, tokenizer

if __name__ == '__main__':
    # Script entry point: load the adapter-wrapped LM and the reward-annotated data,
    # pre-train the critic heads if no saved heads exist, then run iterative policy training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='./math-shepherd/imperfect_data_sparse_threshold_5_new.jsonl')
    parser.add_argument('--base_model_name', type=str, default='Qwen/Qwen2.5-1.5B-Instruct')
    parser.add_argument('--adapter_path', type=str, default='')
    parser.add_argument('--seed', type=int, default=202503)
    parser.add_argument('--num_iter', type=int, default=5)
    parser.add_argument('--num_batchs_per_iter', type=int, default=200)
    parser.add_argument('--steps_per_batch', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--clip_ratio', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=5e-7)
    parser.add_argument('--flag', type=str, default='')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=16)
    parser.add_argument('--type', type=str, default='sarsa')
    parser.add_argument('--mix', type=float, default=0)
    parser.add_argument('--start_accuracy', type=float, default=0)

    args = parser.parse_args()

    # The adapter path determines the run tag, the checkpoint directory and the policy
    # that is reloaded after critic pre-training, so report a clear CLI error instead
    # of failing later with an IndexError on the empty default.
    if not args.adapter_path:
        parser.error('--adapter_path is required (directory of the PEFT adapter to start from)')

    base_model_name = args.base_model_name
    adapter_path = args.adapter_path
    data_path = args.data_path
    # Kept at module level: BPPO.pretrain_sarsa_duel may read `seed` as a global.
    seed = args.seed

    # The tag is the last '_'-separated piece of the adapter's parent directory name;
    # fall back to the only path component when there is no parent directory.
    path_parts = adapter_path.split('/')
    run_dir_name = path_parts[-2] if len(path_parts) >= 2 else path_parts[-1]
    tag = run_dir_name.split('_')[-1]

    lm, tokenizer = load_pretrained_model(adapter_path, base_model_name, is_trainable=True)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    lm.to(device)

    def _load_jsonl(jsonl_path):
        # Read one JSON object per line.
        with open(jsonl_path, 'r') as f:
            return [json.loads(line) for line in f.readlines()]

    if args.mix == 0:
        data = _load_jsonl(data_path)
    else:
        # Mix the leading `mix` fraction of the expert data with the leading
        # `1 - mix` fraction of the imperfect data.
        expert_data = _load_jsonl('./math-shepherd/expert_data_50000.jsonl')
        imperfect_data = _load_jsonl('./math-shepherd/imperfect_data_50000.jsonl')
        expert_data_num = int(len(expert_data) * args.mix)
        imperfect_data_num = int(len(imperfect_data) * (1 - args.mix))
        data = expert_data[:expert_data_num] + imperfect_data[:imperfect_data_num]

    dataset = BPPODataset(data, tokenizer)

    print('Total data length:', len(dataset))

    # Checkpoints are written next to the adapter directory.
    save_dir = os.path.dirname(adapter_path)
    config = {
        'seed': seed,
        'pretrain_value': {
            'num_epochs': 1,
            'batch_size': 8,
            'gamma': 0.99,
            'use_cosine_lr': False,
            'save_dir': save_dir,
            'lr': 1e-5,
        },
        'iterative_train_policy': {
            'num_epochs': 1,
            'num_iter': args.num_iter,
            'gamma': 0.99,
            'num_batchs_per_iter': args.num_batchs_per_iter,
            'batch_size': args.batch_size,
            'clip_ratio': args.clip_ratio,
            'save_dir': save_dir,
            'lr': args.lr,
            'gae_lambda': 0.98,
            'type': args.type,
            'start_accuracy': args.start_accuracy,
            'use_cosine_lr': True,
            'gradient_accumulation_steps': args.gradient_accumulation_steps,
        }
    }

    seed_everything(seed)

    dataloader = DataLoader(dataset, batch_size=config['pretrain_value']['batch_size'], shuffle=True, collate_fn=lambda x: collate_fn(x, tokenizer))

    bppo = BPPO(lm, config, dataloader, device, tokenizer, tag=tag, flag=args.flag, base_model_name=base_model_name)
    save_path = os.path.join(save_dir, 'bppo-Q-value-pretrained')
    if not os.path.exists(save_path):
        # No saved critic heads yet: pre-train them, then reload the adapter policy.
        bppo.pretrain_sarsa_duel()
        bppo.actor_critic.load_policy_pretrained(adapter_path)
        bppo.actor_critic.lm.to(device)

    bppo.load_value_pretrained(save_path)
    # Policy training uses its own batch size, so rebuild the dataloader.
    dataloader = DataLoader(dataset, batch_size=config['iterative_train_policy']['batch_size'], shuffle=True, collate_fn=lambda x: collate_fn(x, tokenizer))
    bppo.dataloader = dataloader

    save_path = bppo.iterative_train_policy()